# Copyright 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import asyncio
import base64
import gc
import json
import os
import queue
import threading
from io import BytesIO
from typing import Dict, List

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from PIL import Image
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
from vllm.lora.request import LoRARequest
from vllm.utils import random_uuid

from utils.metrics import VllmStatLogger
from utils.vllm_backend_utils import TritonSamplingParams

# JSON file (in the model directory) holding the vLLM AsyncEngineArgs fields.
_VLLM_ENGINE_ARGS_FILENAME = "model.json"
# JSON file mapping LoRA adapter names to their local weight paths.
_MULTI_LORA_ARGS_FILENAME = "multi_lora.json"


class TritonPythonModel:
    @classmethod
    def auto_complete_config(cls, auto_complete_model_config):
        """Fill in anything missing from the user-provided model config.

        Registers the backend's expected input/output tensors, forces the
        decoupled transaction policy, and disables Triton-side batching so
        vLLM can batch requests on its own.
        """
        cls._auto_complete_inputs_and_outputs(auto_complete_model_config)

        # The decoupled transaction policy is needed to saturate the vLLM
        # engine for maximum throughput.
        # TODO [DLIS:5233]: Allow asynchronous execution to lift this
        # restriction for cases there is exactly a single response to
        # a single request.
        auto_complete_model_config.set_model_transaction_policy({"decoupled": True})

        # Let vLLM handle batching on its own; disable batching in Triton.
        auto_complete_model_config.set_max_batch_size(0)

        return auto_complete_model_config

    @staticmethod
    def _auto_complete_inputs_and_outputs(auto_complete_model_config):
        # Inputs expected by the backend.
        inputs = [
            {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
            {
                "name": "image",
                "data_type": "TYPE_STRING",
                "dims": [-1],  # can be multiple images as separate elements
                "optional": True,
            },
            {
                "name": "stream",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "sampling_parameters",
                "data_type": "TYPE_STRING",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "exclude_input_in_output",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "return_finish_reason",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "return_cumulative_logprob",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "return_logprobs",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "return_num_input_tokens",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
            {
                "name": "return_num_output_tokens",
                "data_type": "TYPE_BOOL",
                "dims": [1],
                "optional": True,
            },
        ]
        # Outputs expected by the backend.
        outputs = [
            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
            {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
            {"name": "cumulative_logprob", "data_type": "TYPE_FP32", "dims": [-1]},
            {"name": "logprobs", "data_type": "TYPE_STRING", "dims": [-1]},
            {"name": "num_input_tokens", "data_type": "TYPE_UINT32", "dims": [1]},
            {"name": "num_output_tokens", "data_type": "TYPE_UINT32", "dims": [-1]},
        ]

        # Collect input and output names from the provided model config.
        config = auto_complete_model_config.as_dict()
        input_names = []
        output_names = []
        for input in config["input"]:
            input_names.append(input["name"])
        for output in config["output"]:
            output_names.append(output["name"])

        # Add missing inputs and outputs to the model config.
        for input in inputs:
            if input["name"] not in input_names:
                auto_complete_model_config.add_input(input)
        for output in outputs:
            if output["name"] not in output_names:
                auto_complete_model_config.add_output(output)

    def initialize(self, args):
        """Triton hook: start the vLLM engine and its supporting threads.

        Parses the model config, reads backend parameters, launches the vLLM
        engine on a dedicated AsyncIO event-loop thread, attaches the metrics
        logger, and starts the response-forwarding thread.
        """
        self.args = args
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])
        # Numpy dtype used for every "text_output" tensor sent back to clients.
        output_config = pb_utils.get_output_config_by_name(
            self.model_config, "text_output"
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # Setup vLLM engine health check
        self._enable_health_check = self._get_bool_config_param(
            "ENABLE_VLLM_HEALTH_CHECK"
        )
        self._is_healthy = True

        # Initialize engine arguments
        # TODO: Move this into _init_engine(), after moving check metrics enabled.
        self._init_engine_args()

        # Check if metrics are enabled. The ZMQ process cannot be used when metrics are
        # enabled.
        # TODO: Move the check into _setup_metrics().
        self._enable_metrics = (
            self._get_bool_config_param("REPORT_CUSTOM_METRICS")
            and not self._aync_engine_args.disable_log_stats
        )

        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
        self._init_engine()

        # Setup vLLM metrics
        self._setup_metrics()

        # Starting the response thread. It allows vLLM to keep making progress while
        # response sender(s) are sending responses to server frontend.
        self._response_queue = queue.Queue()
        self._response_thread = threading.Thread(target=self._response_loop)
        self._response_thread.start()

    def _init_engine_args(self):
        """Load vLLM engine arguments from `model.json` in the model directory.

        Also validates the device configuration and sets up LoRA support
        before building the AsyncEngineArgs instance.
        """
        # Triton must be configured with the decoupled transaction policy so
        # requests can be forwarded to the vLLM engine asynchronously.
        self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
            self.model_config
        )
        assert (
            self.using_decoupled
        ), "vLLM Triton backend must be configured to use decoupled model transaction policy"

        model_dir = pb_utils.get_model_dir()
        engine_args_filepath = os.path.join(model_dir, _VLLM_ENGINE_ARGS_FILENAME)
        assert os.path.isfile(
            engine_args_filepath
        ), f"'{_VLLM_ENGINE_ARGS_FILENAME}' containing vllm engine args must be provided in '{pb_utils.get_model_dir()}'"
        with open(engine_args_filepath) as config_file:
            self.vllm_engine_config = json.load(config_file)

        # Device and multi-processing settings must agree with the Triton
        # instance kind; validate before the engine reads them.
        self._validate_device_config()

        # Build the Triton LoRA repository when LoRA is enabled in the config.
        self._setup_lora()

        # Create an AsyncEngineArgs from the config from JSON.
        # NOTE: the attribute name (with its historical typo) is read by other
        # methods, so it must not be renamed in isolation.
        self._aync_engine_args = AsyncEngineArgs(**self.vllm_engine_config)

    def _init_engine(self):
        """Start the vLLM engine on a dedicated thread running an AsyncIO loop.

        Blocks until `_run_llm_engine` publishes either the engine or a
        startup exception into `self._llm_engine` (signalled through
        `_llm_engine_start_cv`), and re-raises any startup failure.
        """
        # Run the engine in a separate thread running the AsyncIO event loop.
        self._llm_engine = None
        self._llm_engine_start_cv = threading.Condition()
        self._llm_engine_shutdown_event = asyncio.Event()
        self._event_thread = threading.Thread(
            target=asyncio.run, args=(self._run_llm_engine(),)
        )
        self._event_thread.start()
        # Wait for the engine thread to publish the engine (or an exception).
        with self._llm_engine_start_cv:
            while self._llm_engine is None:
                self._llm_engine_start_cv.wait()

        # The 'threading.Thread()' will not raise the exception here should the engine
        # failed to start, so the exception is passed back via the engine variable.
        if isinstance(self._llm_engine, Exception):
            e = self._llm_engine
            self.logger.log_error(f"[vllm] Failed to start engine: {e}")
            # Reap the engine thread before propagating the failure.
            if self._event_thread is not None:
                self._event_thread.join()
                self._event_thread = None
            raise e

    async def _run_llm_engine(self):
        """Own the vLLM engine for its whole lifetime on the event-loop thread.

        Starts the engine, publishes it (or a startup exception) to
        `_init_engine` through `_llm_engine_start_cv`, then waits until
        `finalize` sets `_llm_engine_shutdown_event`, draining in-flight
        requests before tearing the engine down.
        """
        # Counter to keep track of ongoing request counts.
        self._ongoing_request_count = 0

        try:
            # Start the vLLM engine. The engine lives for the scope of this with
            # statement.
            # TODO: Metrics should work with ZMQ enabled.
            async with build_async_engine_client_from_engine_args(
                engine_args=self._aync_engine_args,
                disable_frontend_multiprocessing=self._enable_metrics,
            ) as engine:
                # Capture the engine event loop and make it visible to other threads.
                self._event_loop = asyncio.get_running_loop()

                # Signal the engine is started and make it visible to other threads.
                with self._llm_engine_start_cv:
                    self._llm_engine = engine
                    self._llm_engine_start_cv.notify_all()

                # Wait for the engine shutdown signal.
                await self._llm_engine_shutdown_event.wait()

                # Wait for the ongoing requests to complete.
                while self._ongoing_request_count > 0:
                    self.logger.log_info(
                        "[vllm] Awaiting remaining {} requests".format(
                            self._ongoing_request_count
                        )
                    )
                    await asyncio.sleep(1)

                # Cancel all tasks in the event loop.
                for task in asyncio.all_tasks(loop=self._event_loop):
                    if task is not asyncio.current_task():
                        task.cancel()
        except Exception as e:
            # Signal and pass the exception back via the engine variable if the engine
            # failed to start. If the engine has started, re-raise the exception.
            with self._llm_engine_start_cv:
                if self._llm_engine is None:
                    self._llm_engine = e
                    self._llm_engine_start_cv.notify_all()
                    return
            raise e

        # Normal shutdown path: drop the engine reference for finalize().
        self._llm_engine = None
        self.logger.log_info("[vllm] Shutdown complete")

    def _validate_device_config(self):
        """Check the Triton instance kind against vLLM's parallelism settings.

        Raises:
            ValueError: if KIND_GPU is combined with tensor_parallel_size > 1.
        """
        triton_kind = self.args["model_instance_kind"]
        triton_device_id = int(self.args["model_instance_device_id"])
        triton_instance = f"{self.args['model_name']}_{triton_device_id}"
        is_kind_gpu = triton_kind == "GPU"

        # Triton's current definition of KIND_GPU makes assumptions that
        # models only use a single GPU. For multi-GPU models, the recommendation
        # is to specify KIND_MODEL to acknowledge that the model will take control
        # of the devices made available to it.
        # NOTE: Consider other parameters that would indicate multi-GPU in the future.
        tensor_parallelism = int(self.vllm_engine_config.get("tensor_parallel_size", 1))
        if is_kind_gpu and tensor_parallelism > 1:
            raise ValueError(
                "KIND_GPU is currently for single-GPU models, please specify KIND_MODEL "
                "in the model's config.pbtxt for multi-GPU models"
            )

        # Pin this instance to the device Triton assigned so that multiple
        # model instances do not oversubscribe the same default device.
        if is_kind_gpu and triton_device_id >= 0:
            self.logger.log_info(
                f"Detected KIND_GPU model instance, explicitly setting GPU device={triton_device_id} for {triton_instance}"
            )
            # vLLM doesn't currently (v0.4.2) expose device selection in the APIs
            os.environ["CUDA_VISIBLE_DEVICES"] = str(triton_device_id)

    def _setup_lora(self):
        """Enable multi-LoRA support when requested by `model.json`.

        When `enable_lora` is "true" (case-insensitive; the value may be a
        string or a bool), loads the LoRA name -> local path mapping from
        `multi_lora.json` in the model directory.

        Raises:
            FileNotFoundError: if LoRA is enabled but `multi_lora.json` is
                missing.
        """
        self.enable_lora = False

        # `enable_lora` may be absent, a string, or a bool.
        if str(self.vllm_engine_config.get("enable_lora", "")).lower() != "true":
            return

        # Load the Triton LoRA weights repository.
        multi_lora_args_filepath = os.path.join(
            pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
        )
        try:
            with open(multi_lora_args_filepath) as lora_file:
                self.lora_repository: Dict[str, str] = json.load(lora_file)
        except FileNotFoundError:
            raise FileNotFoundError(
                f"Triton backend cannot find {multi_lora_args_filepath}."
            )
        self.supported_loras: List[str] = list(self.lora_repository.keys())
        self.supported_loras_len = len(self.supported_loras)
        self.enable_lora = True

    def _setup_metrics(self):
        """Attach a VllmStatLogger to the engine when custom metrics are on.

        Leaves `self._vllm_metrics` as None when metrics are disabled or the
        server was built without metrics support.
        """
        self._vllm_metrics = None
        # TODO: Do not read metrics directly from the vLLM engine, read from prometheus
        #       client to allow the use of ZMQ process when metrics are enabled. See
        #       https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
        if not self._enable_metrics:
            return
        try:
            labels = {
                "model": self.args["model_name"],
                "version": self.args["model_version"],
            }
            # Register vLLM's custom metrics with this model's labels.
            vllm_config = self._llm_engine.engine.vllm_config
            self._vllm_metrics = VllmStatLogger(labels, vllm_config, self.logger)
            self._llm_engine.add_logger("triton", self._vllm_metrics)
        except pb_utils.TritonModelException as e:
            # Metrics may be disabled at the server level; anything else is fatal.
            if "metrics not supported" in str(e):
                self.logger.log_info("[vllm] Metrics not supported")
            else:
                raise e

    def _get_bool_config_param(self, param_name: str) -> bool:
        """Return True iff `param_name` exists in the model config's
        parameters with a string value of "true" (case-insensitive)."""
        params = self.model_config["parameters"]
        if param_name not in params:
            return False
        return params[param_name]["string_value"].lower() == "true"

    def _response_loop(self):
        """Forward queued responses to Triton until the shutdown sentinel.

        Runs on a dedicated thread so vLLM can keep making progress while
        responses are being delivered to the server frontend.
        """
        while True:
            entry = self._response_queue.get()
            # A None entry is the shutdown sentinel (see finalize()).
            if entry is None:
                return
            state, inference_response, flags = entry
            sender = state["response_sender"]
            try:
                sender.send(inference_response, flags)
                # Stop polling cancellation once the last response went out.
                if not state["last_response_generated"]:
                    state["is_cancelled"] = sender.is_cancelled()
            except Exception as e:
                self.logger.log_error(
                    f"An error occurred while sending a response: {e}"
                )
            finally:
                # The sender of the final response owns the request-count
                # decrement (see _generate()).
                if flags == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
                    self._ongoing_request_count -= 1

    def execute(self, requests):
        """Triton hook: schedule each request on the engine's event loop.

        Always returns None; responses are delivered asynchronously through
        each request's response sender (decoupled mode).
        """
        if self._enable_health_check and not self._check_health(requests):
            return None
        for request in requests:
            verified = self._verify_loras(request)
            if verified is None:
                # An error response has already been sent for this request.
                continue
            assert (
                not self._llm_engine_shutdown_event.is_set()
            ), "Cannot create tasks after shutdown has been requested"
            asyncio.run_coroutine_threadsafe(self._generate(verified), self._event_loop)
        return None

    async def _generate(self, request):
        """Run one inference request through vLLM and send its response(s).

        Scheduled on the engine's event loop by `execute`. When `stream` is
        requested, one response per engine iteration is handed to the
        response thread; otherwise a single final response is sent here.
        Cancellation is honored between iterations, and any failure is
        reported back as a final error response.
        """
        response_sender = request.get_response_sender()
        # Shared with _response_loop, which records cancellation state here
        # while it sends streamed responses on our behalf.
        response_state = {
            "response_sender": response_sender,
            "is_cancelled": False,
            "last_response_generated": False,  # last response ready but not yet sent
        }
        self._ongoing_request_count += 1
        # Once the final response is handed to the response thread, that
        # thread owns the decrement instead (flag is flipped below).
        decrement_ongoing_request_count = True
        try:
            request_id = random_uuid()
            (
                prompt,
                stream,
                prepend_input,
                parameters,
                additional_outputs,
            ) = self._get_input_tensors(request)

            sampling_params = TritonSamplingParams.from_dict(parameters, self.logger)
            lora_name = sampling_params.lora_name
            lora_request = None
            if lora_name is not None:
                # LoRA int ids start at 1; 0 is reserved for "no LoRA".
                lora_id = str(self.supported_loras.index(lora_name) + 1)
                lora_int_id = int(lora_id)
                lora_local_path = self.lora_repository[lora_name]
                lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)

            response_iterator = self._llm_engine.generate(
                prompt, sampling_params, request_id, lora_request=lora_request
            )

            # Per-request incremental state used by _create_response to emit
            # only the delta since the previous iteration.
            request_output_state = {}
            async for request_output in response_iterator:
                # Cancellation state will be checked by the response loop and written to
                # the response state if streaming. If not streaming, cancellation state
                # needs to be checked here.
                is_cancelled = response_state["is_cancelled"]
                if not stream:
                    is_cancelled = response_sender.is_cancelled()
                if is_cancelled:
                    self.logger.log_info("[vllm] Cancelling the request")
                    await self._llm_engine.abort(request_id)
                    self.logger.log_info("[vllm] Successfully cancelled the request")

                    if stream:
                        # Add cancelled final response to response loop.
                        response_state["last_response_generated"] = True
                        response = pb_utils.InferenceResponse(
                            error=pb_utils.TritonError(
                                message="Request was cancelled",
                                code=pb_utils.TritonError.CANCELLED,
                            )
                        )
                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                        decrement_ongoing_request_count = False
                        self._response_queue.put_nowait(
                            (response_state, response, flags)
                        )

                    break

                # Send each response if streaming.
                if stream:
                    response = self._create_response(
                        request_output_state,
                        request_output,
                        prepend_input=False,
                        additional_outputs=additional_outputs,
                    )
                    flags = 0
                    if request_output.finished:
                        response_state["last_response_generated"] = True
                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
                        decrement_ongoing_request_count = False
                    self._response_queue.put_nowait((response_state, response, flags))

            # Send the last response which contains all the outputs if not streaming.
            # NOTE(review): if the iterator yielded nothing, or broke on the very
            # first cancellation check, `request_output` is unbound here and this
            # raises NameError (reported via the except path below) — confirm
            # whether that can occur in practice.
            if not stream:
                response_sender.send(
                    self._create_response(
                        request_output_state={},
                        request_output=request_output,
                        prepend_input=prepend_input,
                        additional_outputs=additional_outputs,
                    ),
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                )

        except Exception as e:
            # Report the failure to the client as a final error response.
            self.logger.log_error(f"[vllm] Error generating stream: {e}")
            error = pb_utils.TritonError(f"Error generating stream: {e}")
            text_output_tensor = pb_utils.Tensor(
                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
            )
            response = pb_utils.InferenceResponse(
                output_tensors=[text_output_tensor], error=error
            )
            response_sender.send(
                response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
            )
            raise e

        finally:
            if decrement_ongoing_request_count:
                self._ongoing_request_count -= 1

    def _get_input_tensors(self, request):
        """Extract and normalize the backend's input tensors from a request.

        Returns:
            Tuple of (prompt, stream, prepend_input, parameters,
            additional_outputs): `prompt` is a string, or a vLLM multi-modal
            dict when images are attached; `stream` and `prepend_input` are
            booleans; `parameters` is a JSON string of sampling parameters;
            `additional_outputs` maps each optional output flag to a bool.

        Raises:
            ValueError: if streaming is combined with input prepending.
        """
        # prompt
        prompt = pb_utils.get_input_tensor_by_name(request, "text_input").as_numpy()[0]
        if isinstance(prompt, bytes):
            prompt = prompt.decode("utf-8")

        # image — each element is one base64-encoded image (decoded via PIL)
        images = pb_utils.get_input_tensor_by_name(request, "image")
        if images:
            images_vllm = []
            for image_np in images.as_numpy():
                image_b = base64.b64decode(image_np.decode("utf-8"))
                image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
                images_vllm.append(image_rgb)
            if len(images_vllm) > 0:
                # Switch to vLLM's multi-modal prompt format.
                prompt = {
                    "prompt": prompt,
                    "multi_modal_data": {"image": images_vllm},
                }

        # stream (absent tensor defaults to False; numpy bool when present)
        stream = pb_utils.get_input_tensor_by_name(request, "stream")
        if stream:
            stream = stream.as_numpy()[0]
        else:
            stream = False

        # prepend_input / exclude_input_in_output
        prepend_input = pb_utils.get_input_tensor_by_name(
            request, "exclude_input_in_output"
        )
        if prepend_input:
            # When `exclude_input_in_output` is False, we want to prepend input prompt
            # to output, thus prepend_input should be True, and vice versa.
            prepend_input = not prepend_input.as_numpy()[0]
        elif prepend_input is None and stream:
            # Default when streaming: do not prepend the prompt.
            prepend_input = False
        else:
            # Default when not streaming: prepend the prompt.
            prepend_input = True
        if prepend_input and stream:
            raise ValueError(
                "When streaming, `exclude_input_in_output` = False is not allowed."
            )

        # parameters / sampling_parameters
        # An alternative mechanism to receive serialized parameters as an input
        # tensor, because request parameters are not yet supported via BLS.
        sampling_parameters = pb_utils.get_input_tensor_by_name(
            request, "sampling_parameters"
        )
        if sampling_parameters:
            parameters = sampling_parameters.as_numpy()[0].decode("utf-8")
        else:
            parameters = request.parameters()

        # additional outputs — each optional flag defaults to False when absent
        additional_outputs = {
            "return_finish_reason": None,
            "return_cumulative_logprob": None,
            "return_logprobs": None,
            "return_num_input_tokens": None,
            "return_num_output_tokens": None,
        }
        for tensor_name in additional_outputs.keys():
            tensor = pb_utils.get_input_tensor_by_name(request, tensor_name)
            if tensor:
                tensor = bool(tensor.as_numpy()[0])
            else:
                tensor = False
            additional_outputs[tensor_name] = tensor

        return prompt, stream, prepend_input, parameters, additional_outputs

    def _create_response(
        self, request_output_state, request_output, prepend_input, additional_outputs
    ):
        """Build an InferenceResponse from one vLLM RequestOutput.

        `request_output_state` carries per-request progress (previous text /
        logprob / token-count lengths) between calls so that streamed
        responses contain only the delta generated since the previous call.
        Pass a fresh empty dict for a single non-streamed response.
        """
        output_tensors = []

        # text_output — emit only the text generated since the previous call.
        prepend_prompt = ""
        if "prev_lens_text_output" not in request_output_state:
            # this is the first response
            if prepend_input:
                prepend_prompt = request_output.prompt
            request_output_state["prev_lens_text_output"] = [0] * len(
                request_output.outputs
            )
        prev_lens = request_output_state["prev_lens_text_output"]
        text_output = [
            (prepend_prompt + output.text[prev_len:]).encode("utf-8")
            for output, prev_len in zip(request_output.outputs, prev_lens)
        ]
        request_output_state["prev_lens_text_output"] = [
            len(output.text) for output in request_output.outputs
        ]
        output_tensors.append(
            pb_utils.Tensor(
                "text_output", np.asarray(text_output, dtype=self.output_dtype)
            )
        )

        # finish_reason
        if additional_outputs["return_finish_reason"]:
            finish_reason = [
                str(output.finish_reason) for output in request_output.outputs
            ]
            output_tensors.append(
                pb_utils.Tensor(
                    "finish_reason", np.asarray(finish_reason, dtype=np.object_)
                )
            )

        # cumulative_logprob
        if additional_outputs["return_cumulative_logprob"]:
            cumulative_logprob = [
                output.cumulative_logprob for output in request_output.outputs
            ]
            output_tensors.append(
                pb_utils.Tensor(
                    "cumulative_logprob",
                    np.asarray(cumulative_logprob, dtype=np.float32),
                )
            )

        # logprobs — serialized as JSON, delta-encoded like text_output.
        # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/sequence.py#L37-L58
        if additional_outputs["return_logprobs"]:
            if "prev_lens_logprobs" not in request_output_state:
                request_output_state["prev_lens_logprobs"] = [0] * len(
                    request_output.outputs
                )
            logprobs = []
            for i in range(len(request_output.outputs)):
                output = request_output.outputs[i]
                if output.logprobs is None:
                    logprobs.append("null".encode("utf-8"))
                    continue
                prev_len = request_output_state["prev_lens_logprobs"][i]
                request_output_state["prev_lens_logprobs"][i] = len(output.logprobs)
                # Convert vLLM's Logprob objects into plain JSON-able dicts.
                logprobs_py = []
                for logprob_d_vllm in output.logprobs[prev_len:]:
                    logprob_d_py = {}
                    for token_id, logprob_vllm in logprob_d_vllm.items():
                        logprob_d_py[token_id] = {
                            "logprob": logprob_vllm.logprob,
                            "rank": logprob_vllm.rank,
                            "decoded_token": logprob_vllm.decoded_token,
                        }
                    logprobs_py.append(logprob_d_py)
                logprobs.append(json.dumps(logprobs_py).encode("utf-8"))
            output_tensors.append(
                pb_utils.Tensor("logprobs", np.asarray(logprobs, dtype=np.object_))
            )

        # num_input_tokens
        if additional_outputs["return_num_input_tokens"]:
            num_input_tokens = len(request_output.prompt_token_ids)
            output_tensors.append(
                pb_utils.Tensor(
                    "num_input_tokens", np.asarray(num_input_tokens, dtype=np.uint32)
                )
            )

        # num_output_tokens — delta-encoded like text_output.
        if additional_outputs["return_num_output_tokens"]:
            if "prev_lens_num_output_tokens" not in request_output_state:
                request_output_state["prev_lens_num_output_tokens"] = [0] * len(
                    request_output.outputs
                )
            prev_lens = request_output_state["prev_lens_num_output_tokens"]
            num_output_tokens = [
                (len(output.token_ids) - prev_len)
                for output, prev_len in zip(request_output.outputs, prev_lens)
            ]
            request_output_state["prev_lens_num_output_tokens"] = [
                len(output.token_ids) for output in request_output.outputs
            ]
            output_tensors.append(
                pb_utils.Tensor(
                    "num_output_tokens", np.asarray(num_output_tokens, dtype=np.uint32)
                )
            )

        return pb_utils.InferenceResponse(output_tensors=output_tensors)

    def _verify_loras(self, request):
        """Validate the LoRA named in the request, if any.

        Returns the request unchanged when no LoRA is requested or the LoRA
        is known; otherwise sends a final error response immediately and
        returns None so the request is not processed further.
        """
        # Sampling parameters may arrive either as an input tensor (BLS) or
        # as request parameters.
        parameters_input_tensor = pb_utils.get_input_tensor_by_name(
            request, "sampling_parameters"
        )
        if parameters_input_tensor:
            parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
        else:
            parameters = request.parameters()

        lora_error = None
        lora_name = json.loads(parameters).pop("lora_name", None)
        if lora_name is not None:
            if not self.enable_lora:
                lora_error = pb_utils.TritonError("LoRA feature is not enabled.")
                self.logger.log_info(
                    "[vllm] LoRA is not enabled, please restart the backend with LoRA enabled."
                )
            elif lora_name not in self.supported_loras:
                lora_error = pb_utils.TritonError(
                    f"LoRA {lora_name} is not supported, we currently support {self.supported_loras}"
                )
                self.logger.log_info(f"[vllm] LoRA {lora_name} not found.")

        if lora_error is None:
            return request

        # Reject immediately with a final error response to avoid any
        # further processing of this request.
        error_tensor = pb_utils.Tensor(
            "text_output",
            np.asarray(["[Error] Unsupported LoRA."], dtype=self.output_dtype),
        )
        request.get_response_sender().send(
            pb_utils.InferenceResponse(
                output_tensors=[error_tensor], error=lora_error
            ),
            flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
        )
        return None

    def _check_health(self, requests):
        """Probe engine health; fail pending requests when unhealthy.

        On a failed probe, marks the model unhealthy and requests a
        (non-blocking) model unload. While unhealthy, every request in the
        batch is answered with a final UNAVAILABLE error response.

        Fix over the previous version: once the engine has been marked
        unhealthy, the probe is skipped — re-probing a failed engine would
        raise (and re-request the unload) on every subsequent batch.

        Returns:
            bool: True if the engine is healthy and requests may proceed.
        """
        if self._is_healthy:
            coro = self._llm_engine.check_health()
            # The engine lives on its own event-loop thread; block here until
            # the probe completes.
            future = asyncio.run_coroutine_threadsafe(coro, self._event_loop)
            try:
                future.result()
            except Exception as e:
                self.logger.log_error(
                    f"[vllm] Engine is not healthy and model will be unloaded: {e}"
                )
                pb_utils.unload_model(self.model_config["name"])  # non-blocking
                self._is_healthy = False
        if not self._is_healthy:
            for request in requests:
                request.get_response_sender().send(
                    pb_utils.InferenceResponse(
                        error=pb_utils.TritonError(
                            message="Model is unavailable due to unhealthy vLLM engine",
                            code=pb_utils.TritonError.UNAVAILABLE,
                        )
                    ),
                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
                )
        return self._is_healthy

    def finalize(self):
        """Triton hook: shut down the engine, worker threads, and metrics."""
        self.logger.log_info("[vllm] Issuing finalize to vllm backend")
        # Ask the engine coroutine (running on the event-loop thread) to exit.
        self._event_loop.call_soon_threadsafe(self._llm_engine_shutdown_event.set)

        # Wait for the engine's event-loop thread to finish.
        if self._event_thread is not None:
            self._event_thread.join()
            self._event_thread = None

        # Unblock the response thread with the shutdown sentinel, then join it.
        self._response_queue.put(None)
        if self._response_thread is not None:
            self._response_thread.join()
            self._response_thread = None

        # Stop the metrics logger, if one was attached.
        if self._vllm_metrics is not None:
            self._vllm_metrics.finalize()

        # When using parallel tensors, the stub process may not shutdown due to
        # unreleased references, so manually run the garbage collector once.
        self.logger.log_info("[vllm] Running Garbage Collector on finalize...")
        gc.collect()
        self.logger.log_info("[vllm] Garbage Collector on finalize... done")
